package com.uziot.bucket.common.recursivetask;

import com.uziot.bucket.common.util.sequence.UUID;
import lombok.extern.slf4j.Slf4j;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;


/**
 * @author shidt
 * @version V1.0
 * @className RecursiveTaskTemplate
 * @date 2019-12-20 08:51:35
 * @description 有返回值的工作窃取算法
 */

@Slf4j
public class RecursiveTaskTemplate extends RecursiveTask<List<Object>> {

    /**
     * 任务拆分临界点,应大于1的整数
     */
    private static final Integer MAX_LIMIT = 20;
    private static final long serialVersionUID = 2531295702633789662L;

    private final Integer start;
    private final Integer end;
    private final List<Object> tasks;

    public RecursiveTaskTemplate(Integer start, Integer end, List<Object> tasks) {
        this.start = start;
        this.end = end;
        this.tasks = tasks;
    }

    @Override
    protected List<Object> compute() {
        //任务足够小（类似递归）
        if ((end - start) < MAX_LIMIT) {
            // 分段返回值
            List<Object> results = new ArrayList<>();
            // 取得任务列表
            List<Object> taskList = tasks.subList(start, end);
            log.info("当前取得任务阶段数量：[{}]", end - start);
            for (Object obj : taskList) {
                log.info("正在执行任务，取得任务为：{}", obj);
                results.add(obj);
            }
            return results;
        } else {
            //拆分任务
            int middle = (end + start) / 2;
            RecursiveTaskTemplate leftTask = new RecursiveTaskTemplate(start, middle, tasks);
            RecursiveTaskTemplate rightTask = new RecursiveTaskTemplate(middle, end, tasks);

            // 加入任务
            leftTask.fork();
            rightTask.fork();

            // 汇总各个分任务节点返回值
            List<Object> leftList = leftTask.join();
            List<Object> rightList = rightTask.join();
            leftList.addAll(rightList);
            return leftList;
        }
    }


    public static void main(String[] args) {
        ArrayList<Object> tasks = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            tasks.add(UUID.fastUUID());
        }
        RecursiveTaskTemplate task = new RecursiveTaskTemplate(0, tasks.size(), tasks);

        ForkJoinPool forkjoinPool = new ForkJoinPool();
        Future<List<Object>> result = forkjoinPool.submit(task);

        List<Object> executeResult;
        try {
            executeResult = result.get();
            log.info("----------------执行完成结果[" + executeResult.size() + "]如下--------------");
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            forkjoinPool.shutdown();
        }
    }
}